#!/usr/bin/env python3

"""
Ensemble networks that are trained to either:
Version 1: predict the next state.
Version 2: predict the average return.
"""

import torch
import random
from typing import List
import wandb
import math
import numpy as np
import torch
import einops
from rpi import logger
from rpi.agents.mamba import Ensemble, ActiveStateExplorer,ValueEnsemble
from rpi.helpers.factory import Factory
from rpi.helpers import to_torch
from . import NewStateDetector, extract_states
from .distance_based_methods import EuclideanDistance, WassersteinDistance


def get_state_pred_ensemble(state_dim, num_state_nns: int = 5, device = None):
    """Build an ensemble of next-state predictors and place it on a device.

    Args:
        state_dim: passed to ``Factory.create_state_nn`` for every member and
            given to ``Ensemble`` twice (presumably as input and output size —
            TODO confirm against ``Ensemble``'s signature).
        num_state_nns: number of ensemble members.
        device: target ``torch.device``; when ``None``, CUDA is used if
            available, otherwise CPU.

    Returns:
        The ``Ensemble`` after ``.to(device)`` has been called on it.
    """
    def _build_member():
        return Factory.create_state_nn(state_dim)

    target_device = device
    if target_device is None:
        target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    ensemble = Ensemble(_build_member, num_state_nns, state_dim, state_dim)
    ensemble.to(target_device)
    return ensemble


class StatePredEnsembleNewStateDetector(NewStateDetector):
    """New-state detector driven by disagreement of a next-state prediction ensemble.

    ``experience`` trains the ensemble on (state, next_state) pairs taken from
    consecutive states of each episode. ``batch_evaluate`` and
    ``batch_evaluate_target_expert`` turn the ensemble's spread on a batch of
    states into a per-state discrepancy score.

    ``deviation_estimator`` selects both the training loss and the score:
      * ``'max_dist'``      -- MSE training; the score is the largest pairwise
                               L2 distance between the members' predicted means.
      * ``'std_from_mean'`` -- MSE training; the score is the std of the
                               members' predicted means, averaged over the
                               last axis.
      * ``'std'``           -- negative log-likelihood training; the score is
                               the ``std`` returned by ``forward_stats``,
                               averaged over the last axis.
    """

    _VALID_ESTIMATORS = ('max_dist', 'std_from_mean', 'std')

    def __init__(self, state_dim, deviation_estimator: str = 'std'):
        # Validate first, so an invalid setting fails before an ensemble is
        # built and moved to a device.
        assert deviation_estimator in self._VALID_ESTIMATORS, (
            f'Unknown deviation_estimator: {deviation_estimator}'
        )
        self.state_pred_ensemble = get_state_pred_ensemble(state_dim)
        self.optimizer = torch.optim.Adam(self.state_pred_ensemble.parameters())
        self.deviation_estimator = deviation_estimator

    # ------------------------------------------------------------------ training

    @staticmethod
    def _transition_indices(states) -> List[int]:
        """Return indices ``i`` into ``np.concatenate(states)`` such that
        states ``i`` and ``i + 1`` come from the same episode.

        Each episode contributes every state except its last one. The running
        offset advances by the full episode length, so an index never pairs
        the last state of one episode with the first state of the next.
        Empty and single-state episodes contribute nothing.
        """
        inds: List[int] = []
        offset = 0
        for ep_states in states:
            num = ep_states.shape[0]
            inds.extend(range(offset, offset + num - 1))
            offset += num
        return inds

    def experience(self, all_episodes: List, num_epochs: int = 100, batch_size: int = 8192):
        """Take pairs of (state, next-state) and train the ensemble model.

        This assumes a single policy is used during data collection.

        Args:
            all_episodes: episodes accepted by ``extract_states``. Each episode
                is expected to yield an array whose first axis is time.
            num_epochs: number of passes over the shuffled transitions.
            batch_size: number of transitions per optimizer step.
        """
        states = [extract_states(ep) for ep in all_episodes]
        all_states = np.concatenate(states)
        state_inds = self._transition_indices(states)
        if not state_inds:
            # Nothing to learn from (e.g. only single-state episodes). Without
            # this guard, the epoch log would reference an undefined ``loss``.
            logger.info('No (state, next_state) pairs found; skipping ensemble training.')
            return

        num_batches = math.ceil(len(state_inds) / batch_size)
        step = 0
        for epoch in range(num_epochs):
            # Shuffle state_inds in-place for a new order every epoch.
            random.shuffle(state_inds)

            for batch_idx in range(num_batches):
                batch_inds = np.asarray(
                    state_inds[batch_size * batch_idx: batch_size * (batch_idx + 1)]
                )
                batch_states = to_torch(all_states[batch_inds])
                # Safe: _transition_indices guarantees i + 1 is in the same episode.
                batch_next_states = to_torch(all_states[batch_inds + 1])

                distrs = self.state_pred_ensemble.forward_all(batch_states)
                if self.deviation_estimator in ('max_dist', 'std_from_mean'):
                    # Only the members' means are used at evaluation -> MSE.
                    _losses = [torch.mean((distr.mean - batch_next_states) ** 2) for distr in distrs]
                elif self.deviation_estimator == 'std':
                    # The predicted spread is used at evaluation -> NLL.
                    _losses = [-distr.log_prob(batch_next_states) for distr in distrs]
                else:
                    raise ValueError(f'Unknown value: {self.deviation_estimator}')

                loss = (sum(_losses) / len(_losses)).mean()

                wandb.log({'train/loss': loss.item(), 'train/step': step})

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                step += 1

            logger.info(f'epoch: {epoch + 1} / {num_epochs}\tloss:{loss.item()}')

    # ---------------------------------------------------------------- evaluation

    def _member_means(self, batch_states):
        """Stack each member's predicted mean along dim 1 (detached).

        Result shape: (batch_size, num_ensembles, state_dim).
        """
        distrs = self.state_pred_ensemble.forward_all(batch_states)
        return torch.stack([distr.mean for distr in distrs], dim=1).detach()

    def _max_pairwise_dist(self, batch_states):
        """Largest L2 distance between any two members' means, per state."""
        means = self._member_means(batch_states)
        num_members = means.shape[1]
        means_a = einops.repeat(means, 'b n d -> b n i d', i=num_members)
        means_b = einops.repeat(means, 'b n d -> b j n d', j=num_members)
        diff = torch.linalg.norm(means_a - means_b, dim=-1)
        return torch.amax(diff, dim=(1, 2))

    def _std_from_mean(self, batch_states):
        """Std of the members' means over the ensemble axis, averaged over the last axis."""
        means = self._member_means(batch_states)
        std = torch.sqrt(torch.var(means, dim=1))
        return std.detach().mean(dim=-1)

    def _predictive_std(self, batch_states):
        """The detached ``std`` from ``forward_stats``, shape (batch_size, state_dim)."""
        return self.state_pred_ensemble.forward_stats(batch_states).std.detach()

    def batch_evaluate(self, batch_states: np.ndarray):
        """Return one discrepancy score per state as a numpy array of shape (batch_size,)."""
        batch_states = to_torch(batch_states)
        if self.deviation_estimator == 'max_dist':
            batch_discrepancy = self._max_pairwise_dist(batch_states)
        elif self.deviation_estimator == 'std_from_mean':
            batch_discrepancy = self._std_from_mean(batch_states)
        elif self.deviation_estimator == 'std':
            batch_discrepancy = self._predictive_std(batch_states).mean(dim=-1)
        else:
            raise ValueError(f'Unknown value: {self.deviation_estimator}')

        return batch_discrepancy.cpu().numpy()

    def batch_evaluate_target_expert(self, batch_states: np.ndarray, expert_id: int):
        """Like ``batch_evaluate``, but the ``'std'`` estimator reports a single column.

        ``expert_id`` is only used by the ``'std'`` estimator. It selects column
        ``expert_id`` of the predicted std instead of averaging over all
        columns; the other estimators ignore it. The original comment called
        this "for f_max expert only".

        NOTE(review): the std is documented as (batch_size, state_dim), so this
        indexes a state dimension by an expert index. Confirm that is intended.
        """
        batch_states = to_torch(batch_states)
        if self.deviation_estimator == 'max_dist':
            batch_discrepancy = self._max_pairwise_dist(batch_states)
        elif self.deviation_estimator == 'std_from_mean':
            batch_discrepancy = self._std_from_mean(batch_states)
        elif self.deviation_estimator == 'std':
            batch_discrepancy = self._predictive_std(batch_states)[:, expert_id]
        else:
            raise ValueError(f'Unknown value: {self.deviation_estimator}')

        return batch_discrepancy.cpu().numpy()


# Next-state prediction approach
# from rpi.state_detection.ensemble_network import StatePredEnsembleNewStateDetector
class NextStatePredActiveStateExplorer(ActiveStateExplorer):
    """Exploration rule based on a next-state prediction ensemble.

    ``state_pred_ensemble`` is a detector class (or factory), not an instance.
    It is instantiated here with the state dimension and ``deviation_estimator``,
    then trained on ``all_episodes`` at construction time.
    """

    def __init__(self, value_fns: List[ValueEnsemble], state_pred_ensemble: StatePredEnsembleNewStateDetector, all_episodes: List, sigma: float,deviation_estimator='std') -> None:

        self.value_fns = value_fns
        # NOTE(review): the state dimension is read from
        # all_episodes[0][0][0]['state']; this nesting is assumed, not checked.
        state_dim = len(all_episodes[0][0][0]['state'])
        # deviation_estimator used to be accepted but silently dropped; it is
        # now forwarded (same default, 'std') so callers actually get it.
        self.state_pred_ensemble = state_pred_ensemble(
            state_dim=state_dim, deviation_estimator=deviation_estimator
        )
        self.state_pred_ensemble.experience(all_episodes)
        self.sigma = sigma
        self._all_episodes = all_episodes

    @torch.no_grad()
    def should_explore(self, obs):
        """Score a single observation against the best expert.

        Returns:
            ``(discrepancy < sigma, best_idx, best_valobj, discrepancy)``.
            ``discrepancy`` is the numpy array returned by
            ``batch_evaluate_target_expert`` for a batch of one observation.
        """
        obs = to_torch(obs).unsqueeze(0)
        best_idx, best_valobj = self._get_best_expert(obs)

        # The detector expects numpy input.
        if isinstance(obs, torch.Tensor):
            batch_obs = obs.cpu().detach().numpy()
        else:
            batch_obs = np.asarray(obs)
        batch_discrepancy = self.state_pred_ensemble.batch_evaluate_target_expert(batch_obs, best_idx)

        return batch_discrepancy < self.sigma, best_idx, best_valobj, batch_discrepancy


# Euclidean distance approach
class EuclainDistanceActiveStateExplorer(ActiveStateExplorer):
    """Exploration rule based on Euclidean distance to the best expert's data.

    ``euclain_distance`` is a class/factory (such as ``EuclideanDistance``)
    that is instantiated here with a fixed configuration.
    ``deviation_estimator`` is accepted for signature parity with the other
    explorers and is not used.
    """

    def __init__(self, value_fns: List[ValueEnsemble], euclain_distance: EuclideanDistance, all_episodes: List, sigma: float,deviation_estimator='std') -> None:
        self.value_fns = value_fns
        self.euclain_distance = euclain_distance(
            min_n=2,
            threshold_accept=0,
            option="mean_dist",
            debug=False,
        )
        self.sigma = sigma
        self.all_episodes = all_episodes

    @torch.no_grad()
    def should_explore(self, obs):
        """Return ``(distance < sigma, best_idx, best_valobj, distance)``.

        ``distance`` is what the distance object's ``evaluate`` returns for
        ``self.all_episodes[0][best_idx]`` and the batched observation.
        """
        obs_batch = to_torch(obs).unsqueeze(0)
        expert_idx, expert_valobj = self._get_best_expert(obs_batch)
        expert_data = self.all_episodes[0][expert_idx]
        distance = self.euclain_distance.evaluate(expert_data, obs_batch)
        return distance < self.sigma, expert_idx, expert_valobj, distance



# Wasserstein distance approach
class WassersteinDistanceActiveStateExplorer(ActiveStateExplorer):
    """Exploration rule based on Wasserstein distance to the best expert's data.

    ``ws_distance`` is a class/factory (such as ``WassersteinDistance``) that
    is instantiated here with a fixed configuration. ``deviation_estimator``
    is accepted for signature parity with the other explorers and is not used.
    """

    def __init__(self, value_fns: List[ValueEnsemble], ws_distance: WassersteinDistance, all_episodes: List, sigma: float,deviation_estimator='std') -> None:
        self.value_fns = value_fns
        self._ws_distance = ws_distance(
            min_n=2,
            threshold_accept=0,
            option="mean_dist",
            debug=False,
        )
        self.sigma = sigma
        self.all_episodes = all_episodes

    @torch.no_grad()
    def should_explore(self, obs):
        """Return ``(distance < sigma, best_idx, best_valobj, distance)``.

        ``distance`` is what the distance object's ``evaluate`` returns for
        ``self.all_episodes[0][best_idx]`` and the batched observation.
        """
        obs_batch = to_torch(obs).unsqueeze(0)
        expert_idx, expert_valobj = self._get_best_expert(obs_batch)
        expert_data = self.all_episodes[0][expert_idx]
        distance = self._ws_distance.evaluate(expert_data, obs_batch)
        return distance < self.sigma, expert_idx, expert_valobj, distance
